{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create MNIST Petastorm Dataset\n",
    "\n",
    "In this notebook we walk through how to create a Petastorm dataset from the MNIST images of handwritten digits, and how to save it as a documented, reusable training dataset in the Hopsworks Feature Store. The Petastorm dataset can later be used to train models with TensorFlow, PyTorch, or SparkML."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting Spark application\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<tr><th>ID</th><th>YARN Application ID</th><th>Kind</th><th>State</th><th>Spark UI</th><th>Driver log</th><th>Current session?</th></tr><tr><td>1</td><td>application_1551196216588_0003</td><td>pyspark</td><td>idle</td><td><a target=\"_blank\" href=\"http://hopsworks0.logicalclocks.com:8088/proxy/application_1551196216588_0003/\">Link</a></td><td><a target=\"_blank\" href=\"http://hopsworks0.logicalclocks.com:8042/node/containerlogs/container_e01_1551196216588_0003_01_000001/demo_featurestore_admin000__meb10000\">Link</a></td><td>✔</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SparkSession available as 'spark'.\n"
     ]
    }
   ],
   "source": [
    "from hops import hdfs, featurestore\n",
    "import numpy as np\n",
    "import pydoop.hdfs\n",
    "import gzip\n",
    "from tempfile import TemporaryFile\n",
    "\n",
    "# IMPORTANT: tensorflow must be imported before petastorm.tf_utils due to a bug in petastorm\n",
    "import tensorflow as tf\n",
    "from petastorm.unischema import dict_to_spark_row, Unischema, UnischemaField\n",
    "from pyspark.sql.types import StructType, StructField, IntegerType\n",
    "from petastorm.codecs import ScalarCodec, CompressedImageCodec, NdarrayCodec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: Create a Hopsworks Dataset called \"mnist\" and Upload the MNIST images to Hopsworks\n",
    "\n",
    "The MNIST dataset can be downloaded from http://yann.lecun.com/exdb/mnist/. It consists of four files:\n",
    "\n",
    "- train-images-idx3-ubyte.gz:  training set images (9912422 bytes) \n",
    "- train-labels-idx1-ubyte.gz:  training set labels (28881 bytes) \n",
    "- t10k-images-idx3-ubyte.gz:   test set images (1648877 bytes) \n",
    "- t10k-labels-idx1-ubyte.gz:   test set labels (4542 bytes)\n",
    "\n",
    "The data is stored in the IDX format, a simple format for vectors and multidimensional matrices of various numerical types.\n",
    "\n",
    "The basic format according to http://yann.lecun.com/exdb/mnist/ is:\n",
    "```\n",
    "magic number\n",
    "size in dimension 1\n",
    "size in dimension 2\n",
    "size in dimension 3\n",
    "....\n",
    "size in dimension N\n",
    "data\n",
    "```\n",
    "The magic number is four bytes long. The first 2 bytes are always 0.\n",
    "\n",
    "The third byte codes the type of the data:\n",
    "```\n",
    "0x08: unsigned byte\n",
    "0x09: signed byte\n",
    "0x0B: short (2 bytes)\n",
    "0x0C: int (4 bytes)\n",
    "0x0D: float (4 bytes)\n",
    "0x0E: double (8 bytes)\n",
    "```\n",
    "The fourth byte codes the number of dimensions of the vector/matrix: 1 for vectors, 2 for matrices, and so on.\n",
    "\n",
    "The sizes in each dimension are 4-byte integers (big endian, like in most non-Intel processors).\n",
    "\n",
    "The data is stored like in a C array, i.e. the index in the last dimension changes the fastest.\n",
    "\n",
    "You could also have uploaded the .png images directly and read them using Spark: https://databricks.com/blog/2018/12/10/introducing-built-in-image-data-source-in-apache-spark-2-4.html \n",
    "\n",
    "But to save us the effort of uploading and unzipping 70000 images, we will use Yann LeCun's IDX files.\n",
    "\n",
    "To upload the files: (1) download them from http://yann.lecun.com/exdb/mnist/; (2) go to the dataset browser of your project in Hopsworks, click the button 'Create new dataset', and name the new dataset \"mnist\"; (3) click on the dataset and upload the files.\n",
    "\n",
    "![Petastorm 3](./../images/petastorm3.png \"Petastorm 3\")\n",
    "\n",
    "![Petastorm 4](./../images/petastorm4.png \"Petastorm 4\")\n"
   ]
  },
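  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the header layout described above concrete, here is a minimal sketch that decodes an IDX header with Python's `struct` module. The names `IDX_TYPES`, `parse_idx_header` and `example_header` are illustrative and not part of this notebook's pipeline, and reading the dimension sizes as unsigned integers is an assumption. The header is built by hand so the sketch runs without the MNIST files.\n",
    "\n",
    "```python\n",
    "import struct\n",
    "\n",
    "# Type codes from the third byte of the IDX magic number: (name, bytes per item)\n",
    "IDX_TYPES = {\n",
    "    0x08: ('unsigned byte', 1),\n",
    "    0x09: ('signed byte', 1),\n",
    "    0x0B: ('short', 2),\n",
    "    0x0C: ('int', 4),\n",
    "    0x0D: ('float', 4),\n",
    "    0x0E: ('double', 8),\n",
    "}\n",
    "\n",
    "def parse_idx_header(raw):\n",
    "    '''Decode the header of an already unzipped IDX byte string.\n",
    "\n",
    "    Returns (type_name, item_size, dims, header_size).\n",
    "    '''\n",
    "    # Magic number: two zero bytes, the type code, the number of dimensions\n",
    "    zero1, zero2, type_code, num_dims = struct.unpack('>BBBB', raw[:4])\n",
    "    if zero1 != 0 or zero2 != 0:\n",
    "        raise ValueError('not an IDX file: the first two bytes must be 0')\n",
    "    type_name, item_size = IDX_TYPES[type_code]\n",
    "    # One big-endian 4-byte integer per dimension (assumed unsigned here)\n",
    "    dims = struct.unpack('>' + 'I' * num_dims, raw[4:4 + 4 * num_dims])\n",
    "    header_size = 4 + 4 * num_dims\n",
    "    return type_name, item_size, dims, header_size\n",
    "\n",
    "# Hand-built header shaped like train-images-idx3-ubyte: 60000 x 28 x 28 unsigned bytes\n",
    "example_header = bytes([0, 0, 0x08, 3]) + struct.pack('>III', 60000, 28, 28)\n",
    "print(parse_idx_header(example_header))\n",
    "# prints ('unsigned byte', 1, (60000, 28, 28), 16)\n",
    "```\n",
    "\n",
    "For the image files this gives a 16-byte header (magic number plus three sizes) and for the label files an 8-byte header (magic number plus one size). These are the offsets that have to be skipped before the raw data starts."
   ]
  },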
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hdfs://10.0.2.15:8020/Projects/demo_featurestore_admin000/mnist/README.md\n",
      "hdfs://10.0.2.15:8020/Projects/demo_featurestore_admin000/mnist/t10k-images-idx3-ubyte.gz\n",
      "hdfs://10.0.2.15:8020/Projects/demo_featurestore_admin000/mnist/t10k-labels-idx1-ubyte.gz\n",
      "hdfs://10.0.2.15:8020/Projects/demo_featurestore_admin000/mnist/train-images-idx3-ubyte.gz\n",
      "hdfs://10.0.2.15:8020/Projects/demo_featurestore_admin000/mnist/train-labels-idx1-ubyte.gz"
     ]
    }
   ],
   "source": [
    "files = hdfs.ls(hdfs.project_path() + \"mnist\")\n",
    "for path in files:\n",
    "    print(path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Parse the IDX files into Numpy Arrays"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_mnist(train_num_examples, test_num_examples):\n",
    "    \"\"\"\n",
    "    Parses the MNIST dataset in the IDX format into multi-dimensional numpy arrays\n",
    "    \n",
    "    Args:\n",
    "        :train_num_examples: number of training examples to parse\n",
    "        :test_num_examples: number of test examples to parse\n",
    "    \n",
    "    Returns: train_images, train_labels, test_images, test_labels\n",
    "    \"\"\"\n",
    "    image_size = 28\n",
    "    train_images_path = hdfs.project_path() + \"mnist/train-images-idx3-ubyte.gz\"\n",
    "    train_labels_path = hdfs.project_path() + \"mnist/train-labels-idx1-ubyte.gz\"\n",
    "    test_images_path = hdfs.project_path() + \"mnist/t10k-images-idx3-ubyte.gz\"\n",
    "    test_labels_path = hdfs.project_path() + \"mnist/t10k-labels-idx1-ubyte.gz\"\n",
    "    # IDX header: 4-byte magic number plus one 4-byte size per dimension\n",
    "    image_header_size = 16  # magic number + 3 dimension sizes\n",
    "    label_header_size = 8  # magic number + 1 dimension size\n",
    "\n",
    "    def parse_zip(path, header_size, num_bytes):\n",
    "        # Open the gzipped file in HDFS, skip the IDX header and return the raw data bytes\n",
    "        f_zip = pydoop.hdfs.open(path, \"rb\")\n",
    "        try:\n",
    "            f_unzip = gzip.GzipFile(mode='rb', fileobj=f_zip)\n",
    "            f_unzip.read(header_size)  # skip the IDX header\n",
    "            return f_unzip.read(num_bytes)\n",
    "        finally:\n",
    "            f_zip.close()\n",
    "\n",
    "    # Images: one unsigned byte per pixel, reshaped to (examples, height, width, channels)\n",
    "    train_images_data = np.frombuffer(parse_zip(train_images_path, image_header_size, image_size * image_size * train_num_examples), dtype=np.uint8)\n",
    "    train_images = train_images_data.reshape(train_num_examples, image_size, image_size, 1)\n",
    "    # Labels: one unsigned byte per example, converted to int64 and reshaped to (examples, 1)\n",
    "    train_labels_data = np.frombuffer(parse_zip(train_labels_path, label_header_size, train_num_examples), dtype=np.uint8).astype(np.int64)\n",
    "    train_labels = train_labels_data.reshape(train_num_examples, 1)\n",
    "    test_images_data = np.frombuffer(parse_zip(test_images_path, image_header_size, image_size * image_size * test_num_examples), dtype=np.uint8)\n",
    "    test_images = test_images_data.reshape(test_num_examples, image_size, image_size, 1)\n",
    "    test_labels_data = np.frombuffer(parse_zip(test_labels_path, label_header_size, test_num_examples), dtype=np.uint8).astype(np.int64)\n",
    "    test_labels = test_labels_data.reshape(test_num_examples, 1)\n",
    "    return train_images, train_labels, test_images, test_labels"
   ]
  },
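  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The skip-the-header-then-reshape logic of `read_mnist` can also be tried without HDFS. The sketch below is only a stand-in under stated assumptions: it builds a small fake gzipped IDX image file in memory (`make_fake_idx_images` and `parse_images` are hypothetical helpers, not part of this notebook's pipeline) and then applies the same steps as the function above.\n",
    "\n",
    "```python\n",
    "import gzip\n",
    "import io\n",
    "import struct\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def make_fake_idx_images(num_images, image_size):\n",
    "    '''Builds a gzipped IDX image file in memory (hypothetical helper for this sketch).'''\n",
    "    header = bytes([0, 0, 0x08, 3]) + struct.pack('>III', num_images, image_size, image_size)\n",
    "    pixels = bytes(i % 256 for i in range(num_images * image_size * image_size))\n",
    "    return gzip.compress(header + pixels)\n",
    "\n",
    "def parse_images(gz_bytes, num_images, image_size, header_size=16):\n",
    "    '''Unzips the buffer, skips the IDX header and reshapes the pixel bytes.'''\n",
    "    with gzip.GzipFile(mode='rb', fileobj=io.BytesIO(gz_bytes)) as f_unzip:\n",
    "        f_unzip.read(header_size)  # skip the IDX header\n",
    "        data = f_unzip.read(image_size * image_size * num_images)\n",
    "    images = np.frombuffer(data, dtype=np.uint8)\n",
    "    # (examples, height, width, channels), the layout used by read_mnist\n",
    "    return images.reshape(num_images, image_size, image_size, 1)\n",
    "\n",
    "fake_file = make_fake_idx_images(num_images=3, image_size=28)\n",
    "images = parse_images(fake_file, num_images=3, image_size=28)\n",
    "print(images.shape, images.dtype)\n",
    "# prints (3, 28, 28, 1) uint8\n",
    "```\n",
    "\n",
    "Because the data is stored like a C array, a plain `reshape` is enough: the last index changes fastest, so consecutive bytes fill one image row before moving on to the next row."
   ]
  },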
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_images, train_labels, test_images, test_labels = read_mnist(60000, 10000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: Quick Data Validation\n",
    "\n",
    "To validate that we parsed the data correctly, we can visualize a few images and compare them to their labels. Plotting in a distributed computing setting is a little tricky; it is explained in more detail here: https://hopsworks.readthedocs.io/en/0.9/user_guide/hopsworks/jupyter.html#plotting-with-pyspark-kernel.\n",
    "\n",
    "TL;DR: we can save a few images to HDFS and read them in the %%local environment to plot them."
   ]
  },
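  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition to the visual check, a cheap programmatic check is to count the labels per class. If the header offset were wrong, header bytes would likely show up as label values outside 0 to 9. The sketch below uses a hypothetical helper, `summarize_labels`, which is not part of the hops or petastorm APIs, and runs on a tiny hand-made list so it works anywhere.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "def summarize_labels(labels, num_classes=10):\n",
    "    '''Counts examples per class and rejects labels outside 0..num_classes-1.'''\n",
    "    counts = Counter(int(label) for label in labels)\n",
    "    unexpected = sorted(l for l in counts if not 0 <= l < num_classes)\n",
    "    if unexpected:\n",
    "        raise ValueError('unexpected label values: ' + str(unexpected))\n",
    "    # Position c of the result holds the number of examples with label c\n",
    "    return [counts.get(c, 0) for c in range(num_classes)]\n",
    "\n",
    "# Tiny hand-made example\n",
    "print(summarize_labels([5, 0, 4, 1, 9, 2, 1, 3, 1, 4]))\n",
    "# prints [1, 3, 1, 1, 2, 1, 0, 0, 0, 1]\n",
    "```\n",
    "\n",
    "With the arrays parsed in this notebook, the labels have shape `(n, 1)`, so one would pass a flattened view such as `train_labels.ravel()`."
   ]
  },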
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Save a handful of images/labels to HDFS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hops import numpy_helper as hops_numpy\n",
    "hops_numpy.save(\"mnist/train_img_val.npy\", train_images[0:2])\n",
    "hops_numpy.save(\"mnist/train_lbl_val.npy\", train_labels[0:2])\n",
    "hops_numpy.save(\"mnist/test_img_val.npy\", test_images[0:2])\n",
    "hops_numpy.save(\"mnist/test_lbl_val.npy\", test_labels[0:2])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Read sample images from HDFS in %%local and plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### Utility functions for plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%local\n",
    "import matplotlib.pyplot as plt\n",
    "from hops import hdfs\n",
    "import numpy as np\n",
    "from tempfile import TemporaryFile\n",
    "from hops import numpy_helper as hops_numpy\n",
    "\n",
    "def plot_train_example(example_num):\n",
    "    \"\"\"\n",
    "    Utility function that plots a training image together with its label \n",
    "    \n",
    "    Args:\n",
    "         :example_num: the id of the example to plot\n",
    "    \n",
    "    Returns: None\n",
    "    \"\"\"\n",
    "    image = np.asarray(train_images_val[example_num]).squeeze()\n",
    "    plt.imshow(image)\n",
    "    print(\"label:\" + str(train_lbls_val[example_num][0]))\n",
    "    \n",
    "def plot_test_example(example_num):\n",
    "    \"\"\"\n",
    "    Utility function that plots a test image together with its label\n",
    "\n",
    "    Args:\n",
    "         :example_num: the id of the example to plot\n",
    "    \n",
    "    Returns: None\n",
    "    \"\"\"\n",
    "    image = np.asarray(test_images_val[example_num]).squeeze()\n",
    "    plt.imshow(image)\n",
    "    print(\"label:\" + str(test_lbls_val[example_num][0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### Read images in %%local"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%local\n",
    "train_images_val = hops_numpy.load(\"mnist/train_img_val.npy\")\n",
    "train_lbls_val = hops_numpy.load(\"mnist/train_lbl_val.npy\")\n",
    "test_images_val = hops_numpy.load(\"mnist/test_img_val.npy\")\n",
    "test_lbls_val = hops_numpy.load(\"mnist/test_lbl_val.npy\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### Plot using matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "label:5\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADoBJREFUeJzt3X2MXOV1x/HfyXq9jo1JvHHYboiLHeMEiGlMOjIgLKCiuA5CMiiKiRVFDiFxmuCktK4EdavGrWjlVgmRQynS0ri2I95CAsJ/0CR0FUGiwpbFMeYtvJlNY7PsYjZgQ4i9Xp/+sdfRBnaeWc/cmTu75/uRVjtzz71zj6792zszz8x9zN0FIJ53Fd0AgGIQfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQU1r5M6mW5vP0KxG7hII5bd6U4f9kE1k3ZrCb2YrJG2W1CLpP9x9U2r9GZqls+2iWnYJIKHHuye8btVP+82sRdJNkj4h6QxJq83sjGofD0Bj1fKaf6mk5919j7sflnSHpJX5tAWg3moJ/8mSfjXm/t5s2e8xs7Vm1mtmvcM6VMPuAOSp7u/2u3uXu5fcvdSqtnrvDsAE1RL+fZLmjbn/wWwZgEmglvA/ImmRmS0ws+mSPi1pRz5tAai3qof63P2Ima2T9CONDvVtcfcnc+sMQF3VNM7v7vdJui+nXgA0EB/vBYIi/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+IKiaZuk1sz5JByWNSDri7qU8mkJ+bFr6n7jl/XPruv9n/np+2drIzKPJbU9ZOJisz/yKJesv3zC9bG1n6c7ktvtH3kzWz75rfbJ+6l89nKw3g5rCn/kTd9+fw+MAaCCe9gNB1Rp+l/RjM3vUzNbm0RCAxqj1af8yd99nZidJut/MfuHuD45dIfujsFaSZmhmjbsDkJeazvzuvi/7PSjpHklLx1mny91L7l5qVVstuwOQo6rDb2azzGz2sduSlkt6Iq/GANRXLU/7OyTdY2bHHuc2d/9hLl0BqLuqw+/ueyR9LMdepqyW0xcl697Wmqy/dMF7k/W3zik/Jt3+nvR49U8/lh7vLtJ//WZ2sv4v/7YiWe8587aytReH30puu2ng4mT9Az/1ZH0yYKgPCIrwA0ERfiAowg8ERfiBoAg/EFQe3+oLb+TCjyfrN2y9KVn/cGv5r55OZcM+kqz//Y2fS9anvZkebjv3rnVla7P3HUlu27Y/PRQ4s7cnWZ8MOPMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM8+eg7ZmXkvVHfzsvWf9w60Ce7eRqff85yfqeN9KX/t668Ptla68fTY/Td3z7f5L1epr8X9itjDM/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRl7o0b0TzR2v1su6hh+2sWQ1eem6wfWJG+vHbL7hOS9ce+cuNx93TM9fv/KFl/5IL0OP7Ia68n635u+au7930tuakWrH4svQLeoce7dcCH0nOXZzjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQFcf5zWyLpEslDbr74mxZu6Q7Jc2X1Cdplbv/utLOoo7zV9Iy933J+sirQ8n6i7eVH6t/8vwtyW2X/vNXk/WTbiruO/U4fnmP82+V9PaJ0K+T1O3uiyR1Z/cBTCIVw+/uD0p6+6lnpaRt2e1tki7LuS8AdVbta/4Od+/Pbr8sqSOnfgA0SM1v+PnomwZl3zgws7Vm1mtmvcM6VOvuAOSk2vAPmFmnJGW/B8ut6O5d7l5y91Kr2qrcHYC8VRv+HZLWZLfXSLo3n3YANErF8JvZ7ZIekvQRM9trZldJ2iTpYjN7TtKfZvcBTCIVr9vv7qvLlBiw
z8nI/ldr2n74wPSqt/3oZ55K1l+5uSX9AEdHqt43isUn/ICgCD8QFOEHgiL8QFCEHwiK8ANBMUX3FHD6tc+WrV15ZnpE9j9P6U7WL/jU1cn67DsfTtbRvDjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPNPAalpsl/98unJbf9vx1vJ+nXXb0/W/2bV5cm6//w9ZWvz/umh5LZq4PTxEXHmB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgKk7RnSem6G4+Q58/N1m/9evfSNYXTJtR9b4/un1dsr7olv5k/cievqr3PVXlPUU3gCmI8ANBEX4gKMIPBEX4gaAIPxAU4QeCqjjOb2ZbJF0qadDdF2fLNkr6oqRXstU2uPt9lXbGOP/k4+ctSdZP3LQ3Wb/9Qz+qet+n/eQLyfpH/qH8dQwkaeS5PVXve7LKe5x/q6QV4yz/lrsvyX4qBh9Ac6kYfnd/UNJQA3oB0EC1vOZfZ2a7zWyLmc3JrSMADVFt+G+WtFDSEkn9kr5ZbkUzW2tmvWbWO6xDVe4OQN6qCr+7D7j7iLsflXSLpKWJdbvcveTupVa1VdsngJxVFX4z6xxz93JJT+TTDoBGqXjpbjO7XdKFkuaa2V5JX5d0oZktkeSS+iR9qY49AqgDvs+PmrR0nJSsv3TFqWVrPdduTm77rgpPTD/z4vJk/fVlrybrUxHf5wdQEeEHgiL8QFCEHwiK8ANBEX4gKIb6UJjv7U1P0T3Tpifrv/HDyfqlX72m/GPf05PcdrJiqA9ARYQfCIrwA0ERfiAowg8ERfiBoAg/EFTF7/MjtqPL0pfufuFT6Sm6Fy/pK1urNI5fyY1DZyXrM+/trenxpzrO/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOP8U5yVFifrz34tPdZ+y3nbkvXzZ6S/U1+LQz6crD88tCD9AEf7c+xm6uHMDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBVRznN7N5krZL6pDkkrrcfbOZtUu6U9J8SX2SVrn7r+vXalzTFpySrL9w5QfK1jZecUdy20+esL+qnvKwYaCUrD+w+Zxkfc629HX/kTaRM/8RSevd/QxJ50i62szOkHSdpG53XySpO7sPYJKoGH5373f3ndntg5KelnSypJWSjn38a5uky+rVJID8HddrfjObL+ksST2SOtz92OcnX9boywIAk8SEw29mJ0j6gaRr3P3A2JqPTvg37qR/ZrbWzHrNrHdYh2pqFkB+JhR+M2vVaPBvdfe7s8UDZtaZ1TslDY63rbt3uXvJ3UutasujZwA5qBh+MzNJ35H0tLvfMKa0Q9Ka7PYaSffm3x6AepnIV3rPk/RZSY+b2a5s2QZJmyR9z8yukvRLSavq0+LkN23+Hybrr/9xZ7J+xT/+MFn/8/fenazX0/r+9HDcQ/9efjivfev/Jredc5ShvHqqGH53/5mkcvN9X5RvOwAahU/4AUERfiAowg8ERfiBoAg/EBThB4Li0t0TNK3zD8rWhrbMSm775QUPJOurZw9U1VMe1u1blqzvvDk9Rffc7z+RrLcfZKy+WXHmB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgwozzH/6z9GWiD//lULK+4dT7ytaWv/vNqnrKy8DIW2Vr5+9Yn9z2tL/7RbLe/lp6nP5osopmxpkfCIrwA0ERfiAowg8ERfiBoAg/EBThB4IKM87fd1n679yzZ95Vt33f9NrCZH3zA8uTdRspd+X0Uadd/2LZ2qKBnuS2I8kqpjLO/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QlLl7egWzeZK2S+qQ5JK63H2zmW2U9EVJr2SrbnD38l96l3SitfvZxqzeQL30eLcO+FD6gyGZiXzI54ik9e6+08xmS3rUzO7Pat9y929U2yiA4lQMv7v3S+rPbh80s6clnVzvxgDU13G95jez+ZLOknTsM6PrzGy3mW0xszlltllrZr1m1jusQzU1CyA/Ew6/
mZ0g6QeSrnH3A5JulrRQ0hKNPjP45njbuXuXu5fcvdSqthxaBpCHCYXfzFo1Gvxb3f1uSXL3AXcfcfejkm6RtLR+bQLIW8Xwm5lJ+o6kp939hjHLO8esdrmk9HStAJrKRN7tP0/SZyU9bma7smUbJK02syUaHf7rk/SlunQIoC4m8m7/zySNN26YHNMH0Nz4hB8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiCoipfuznVnZq9I+uWYRXMl7W9YA8enWXtr1r4keqtWnr2d4u7vn8iKDQ3/O3Zu1uvupcIaSGjW3pq1L4neqlVUbzztB4Ii/EBQRYe/q+D9pzRrb83al0Rv1Sqkt0Jf8wMoTtFnfgAFKST8ZrbCzJ4xs+fN7LoieijHzPrM7HEz22VmvQX3ssXMBs3siTHL2s3sfjN7Lvs97jRpBfW20cz2Zcdul5ldUlBv88zsJ2b2lJk9aWZ/kS0v9Ngl+irkuDX8ab+ZtUh6VtLFkvZKekTSand/qqGNlGFmfZJK7l74mLCZnS/pDUnb3X1xtuxfJQ25+6bsD+ccd7+2SXrbKOmNomduziaU6Rw7s7SkyyR9TgUeu0Rfq1TAcSvizL9U0vPuvsfdD0u6Q9LKAvpoeu7+oKShty1eKWlbdnubRv/zNFyZ3pqCu/e7+87s9kFJx2aWLvTYJfoqRBHhP1nSr8bc36vmmvLbJf3YzB41s7VFNzOOjmzadEl6WVJHkc2Mo+LMzY30tpmlm+bYVTPjdd54w++dlrn7xyV9QtLV2dPbpuSjr9maabhmQjM3N8o4M0v/TpHHrtoZr/NWRPj3SZo35v4Hs2VNwd33Zb8HJd2j5pt9eODYJKnZ78GC+/mdZpq5ebyZpdUEx66ZZrwuIvyPSFpkZgvMbLqkT0vaUUAf72Bms7I3YmRmsyQtV/PNPrxD0prs9hpJ9xbYy+9plpmby80srYKPXdPNeO3uDf+RdIlG3/F/QdLfFtFDmb4+JOmx7OfJonuTdLtGnwYOa/S9kaskvU9St6TnJP23pPYm6u27kh6XtFujQessqLdlGn1Kv1vSruznkqKPXaKvQo4bn/ADguINPyAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQf0/sEWOix6VKakAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%local\n",
    "%matplotlib inline\n",
    "plot_train_example(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "label:0\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADuNJREFUeJzt3X+QVfV5x/HPw3bll+hIDBtCSIkKUkobiBuMjQlJrA7YTNGZhoTpGEptyUyixWjbOLYzddKZDs2YWNNgUhKJmB+YzqiR6VCjbplaE0JYkIiKBkOWCiJEoAV/4S779I89pBvd872Xe8+95+4+79fMzt57nnPueebCZ8+993vO/Zq7C0A8o8puAEA5CD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaB+o5k7O81G+xiNb+YugVBe08t63Y9bNevWFX4zWyDpNkltkr7h7itT64/ReF1ol9SzSwAJm72r6nVrftlvZm2SVklaKGmWpCVmNqvWxwPQXPW8558n6Vl33+3ur0u6W9KiYtoC0Gj1hH+KpOcG3d+bLfs1ZrbczLrNrLtXx+vYHYAiNfzTfndf7e6d7t7ZrtGN3h2AKtUT/n2Spg66/45sGYBhoJ7wb5E03czeZWanSfqEpPXFtAWg0Woe6nP3PjO7RtIPNDDUt8bdnyysMwANVdc4v7tvkLShoF4ANBGn9wJBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QVFOn6MbI0/eRC5L1/Z/On6LtpxetTW777k1Lk/W3rzotWW/buC1Zj44jPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EVdc4v5n1SDom6YSkPnfvLKIptI7++XOT9S+v+Uqyfl57/n+x/gr7fuyibybrz3SeSNb/atr7KuwhtiJO8vmwu79YwOMAaCJe9gNB1Rt+l/SgmW01s+VFNASgOep92X+xu+8zs0mSHjKzp939kcErZH8UlkvSGI2rc3cAilLXkd/d92W/D0q6T9K8IdZZ7e6d7t7ZrtH17A5AgWoOv5mNN7MJJ29LukzSE0U1BqCx6nnZ3yHpPjM7+TjfdfcHCukKQMPVHH533y3p3QX2ghL0XpY+NeOvb/9Wsj6jPX1NfX9iNH93b29y2//tT79NnFvhXeTxhe/NrY3duCO5bf9rr6UffARgqA8IivADQRF+ICjCDwRF+IGgCD8QFF/dPQK0nXFGbu3lD85MbvvZW7+brH947EsV9l778ePOI7+XrHfdflGy/sObv5ysP/SNr+XWZn37muS253xuU7I+EnDkB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgGOcfAfbeNSW3tuW9q5rYyan5/KQtyfoDp6fPA1jWc1myvnbaw7m1M2YdSm4bAUd+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf5hoO8jFyTr6+bkT5M9Sumv1q5k2Z5LkvXuh38rWd9xdX5vG18dk9x2UveryfqzR9LfVdD+Dxtza6MsuWkIHPmBoAg/EBThB4Ii/EBQhB8IivADQRF+IChz9/QKZmskfVTSQXefnS2bKOl7kqZJ6pG02N2PVNrZGTbRL7T0uHFE/fPnJuv/tPb2ZP289tpP1/jDp69M1tv+6OVk/fAfnJ+sH5qdP6A+Y9VzyW37ntubrFfyb/u25tb2n0ifQ/CnS/8iWW/buK2mnhpts3fpqB+u6iyGao78d0pa8IZlN0rqcvfpkrqy+wCGkYrhd/dHJB1+w+JFktZmt9dKuqLgvgA0WK3v+TvcfX92+wVJHQX1A6BJ6v7Azwc+NMj94MDMlptZt5l19+p4vbsDUJBaw3/AzCZLUvb7YN6K7r7a3TvdvbNdo2vcHYCi1Rr+9ZKWZreXSrq/mHYANEvF8JvZOkmbJJ1v
ZnvN7GpJKyVdama7JP1+dh/AMFJxgNjdl+SUGLCvkl3w28n6i9enx5xntKevyd+a+CjlP16aldz20N1Tk/W3HEnPU3/mt3+cridqfcktG6ujLf0W9NB1ryTrk/K/KmDY4Aw/ICjCDwRF+IGgCD8QFOEHgiL8QFB8dXcBRo0bl6z3feFosv7jmfcm67/oez1Zv/6mG3JrZ/3Xfye3nTQ+9+RMSdKJZHXkmjd5T7Le05w2GoojPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ExTh/AV6dn75k9wcz01+9Xcmfrfhssj7h+/mX1ZZ52SxaG0d+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf4C/O7fb0/WR1X4G7tsT/pb0Md+/yen3BOkdmvLrfWmZ6ZXm1VYYQTgyA8ERfiBoAg/EBThB4Ii/EBQhB8IivADQVUc5zezNZI+Kumgu8/Olt0s6c8l/TJb7SZ339CoJlvB/1x1UW7tbztuSW7brwpTbD+Ynkb7nfpRso6h9Xr+rAP96k9u+8DO9L/JdG2rqadWUs2R/05JC4ZYfqu7z8l+RnTwgZGoYvjd/RFJh5vQC4Amquc9/zVm9riZrTGzswrrCEBT1Br+r0o6V9IcSfslfTFvRTNbbmbdZtbdq+M17g5A0WoKv7sfcPcT7t4v6euS5iXWXe3une7e2a7RtfYJoGA1hd/MJg+6e6WkJ4ppB0CzVDPUt07ShySdbWZ7Jf2dpA+Z2RxJroHZij/VwB4BNEDF8Lv7kiEW39GAXlpa39j82pmj0uP4m15Lv905567n0/tOVkeuUePGJetP3zK7wiNsza388e6FyS1nrvhFsp5/BsHwwRl+QFCEHwiK8ANBEX4gKMIPBEX4gaD46u4mOHTi9GS9b3dPcxppMZWG8p5Z+TvJ+tOLvpKs//srZ+bWnl91XnLbCUfypz0fKTjyA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPM3wV/+8GPJ+ozEpafDXf/8ubm1g9e/mtx2Z2d6HP+SHR9P1scv2J1bm6CRP45fCUd+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf5qWX5pVIW/obddvC5ZX6UZtXTUEvZ8Pn/qckm655Nfyq3NaE9/5fl7frI0WX/7lU8l60jjyA8ERfiBoAg/EBThB4Ii/EBQhB8IivADQVUc5zezqZLuktQhySWtdvfbzGyipO9JmiapR9Jidz/SuFZL5vmlfvUnN50/9lCyft2dFyTr534z/fjtLxzLrR2Y/9bkthM/vjdZv/adXcn6wnHp7yJY/3JHbu2TOxYktz37X8Yn66hPNUf+Pkk3uPssSe+T9BkzmyXpRkld7j5dUld2H8AwUTH87r7f3bdlt49J2ilpiqRFktZmq62VdEWjmgRQvFN6z29m0yTNlbRZUoe7789KL2jgbQGAYaLq8JvZ6ZLukXSdux8dXHN3V867YjNbbmbdZtbdq+N1NQugOFWF38zaNRD877j7vdniA2Y2OatPlnRwqG3dfbW7d7p7Z7tGF9EzgAJUDL+ZmaQ7JO1098GXaK2XdPKyq6WS7i++PQCNUs0lve+XdJWkHWa2PVt2k6SVkv7VzK6WtEfS4sa0OPyNsfTTvPPSryXrj35gTLK+6/jbcmvLzuxJbluvFc9/IFl/4EdzcmvTV/D12WWqGH53f1T5V7NfUmw7AJqFM/yAoAg/EBThB4Ii/EBQhB8IivADQdnAmbnNcYZN9AtteI4Ots04N7c2Y92e5Lb/+LZNde270leDV7qkOOWx4+nHXvKfy5P1GctG7vTiw9Fm79JRP5z4ovn/x5EfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Jiiu4qnfjZz3Nruz42LbntrGuvTdafWvzPtbRUlZkbPp2sn3/7K8n6jMcYxx+pOPIDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBczw+MIFzPD6Aiwg8ERfiBoAg/EBThB4Ii/EBQhB8IqmL4
zWyqmW00s6fM7EkzW5Etv9nM9pnZ9uzn8sa3C6Ao1XyZR5+kG9x9m5lNkLTVzB7Kare6+y2Naw9Ao1QMv7vvl7Q/u33MzHZKmtLoxgA01im95zezaZLmStqcLbrGzB43szVmdlbONsvNrNvMunt1vK5mARSn6vCb2emS7pF0nbsflfRVSedKmqOBVwZfHGo7d1/t7p3u3tmu0QW0DKAIVYXfzNo1EPzvuPu9kuTuB9z9hLv3S/q6pHmNaxNA0ar5tN8k3SFpp7t/adDyyYNWu1LSE8W3B6BRqvm0//2SrpK0w8y2Z8tukrTEzOZIckk9kj7VkA4BNEQ1n/Y/Kmmo64M3FN8OgGbhDD8gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQTZ2i28x+KWnPoEVnS3qxaQ2cmlbtrVX7kuitVkX29pvu/tZqVmxq+N+0c7Nud+8srYGEVu2tVfuS6K1WZfXGy34gKMIPBFV2+FeXvP+UVu2tVfuS6K1WpfRW6nt+AOUp+8gPoCSlhN/MFpjZM2b2rJndWEYPecysx8x2ZDMPd5fcyxozO2hmTwxaNtHMHjKzXdnvIadJK6m3lpi5OTGzdKnPXavNeN30l/1m1ibpZ5IulbRX0hZJS9z9qaY2ksPMeiR1unvpY8Jm9kFJL0m6y91nZ8u+IOmwu6/M/nCe5e6fa5Hebpb0UtkzN2cTykwePLO0pCsk/YlKfO4SfS1WCc9bGUf+eZKedffd7v66pLslLSqhj5bn7o9IOvyGxYskrc1ur9XAf56my+mtJbj7fnfflt0+JunkzNKlPneJvkpRRvinSHpu0P29aq0pv13Sg2a21cyWl93MEDqyadMl6QVJHWU2M4SKMzc30xtmlm6Z566WGa+Lxgd+b3axu79H0kJJn8le3rYkH3jP1krDNVXN3NwsQ8ws/StlPne1znhdtDLCv0/S1EH335Etawnuvi/7fVDSfWq92YcPnJwkNft9sOR+fqWVZm4eamZptcBz10ozXpcR/i2SppvZu8zsNEmfkLS+hD7exMzGZx/EyMzGS7pMrTf78HpJS7PbSyXdX2Ivv6ZVZm7Om1laJT93LTfjtbs3/UfS5Rr4xP/nkv6mjB5y+jpH0k+znyfL7k3SOg28DOzVwGcjV0t6i6QuSbskPSxpYgv19i1JOyQ9roGgTS6pt4s18JL+cUnbs5/Ly37uEn2V8rxxhh8QFB/4AUERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8I6v8AG8x2aarNGp8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%local\n",
    "%matplotlib inline\n",
    "plot_train_example(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "label:7\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADaVJREFUeJzt3X+MXOV1xvHnib1e4jU0GILrGgcnhKA6NDjVxiSCVo4IKZAgEyWhWKrlSpRFLUhQRW2Rq6iWWqUUhSC3SSM5wY1BBGgCCCtx01CrrYVKHS/I2IBpTajT2DVewLQ2AfwDn/6x19EGdt5d5ted9fl+pNXO3HPv3KPrfXzvzDszryNCAPJ5R90NAKgH4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kNT0bu5shvvjJA10c5dAKq/rZzochzyZdVsKv+1LJa2WNE3SNyPiltL6J2lAF/jiVnYJoGBzbJz0uk1f9tueJulrki6TtFDSMtsLm308AN3VynP+xZKejYjnIuKwpHslLW1PWwA6rZXwz5P00zH3d1fLfoHtIdvDtoeP6FALuwPQTh1/tT8i1kTEYEQM9qm/07sDMEmthH+PpPlj7p9ZLQMwBbQS/i2SzrH9XtszJF0taX172gLQaU0P9UXEUds3SPpHjQ71rY2Ip9rWGYCOammcPyI2SNrQpl4AdBFv7wWSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCplmbptb1L0kFJb0g6GhGD7WgKQOe1FP7KxyPixTY8DoAu4rIfSKrV8IekH9p+zPZQOxoC0B2tXvZfFBF7bJ8h6WHbz0TEprErVP8pDEnSSZrZ4u4AtEtLZ/6I2FP9HpH0oKTF46yzJiIGI2KwT/2t7A5AGzUdftsDtk8+flvSJyU92a7GAHRWK5f9cyQ9aPv443w7In7Qlq4AdFzT4Y+I5ySd38ZeAHQRQ31AUoQfSIrwA0kRfiApwg8kRfiBpNrxqb4UXrr2Yw1r71n+bHHbZ0bmFOuHD/UV6/PuKddn7n6lYe3Y1qeL2yIvzvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBTj/JP0x3/07Ya1zw68XN747BZ3vqRc3nX01Ya11S98vMWdT10/GjmrYW3gtl8qbjt942PtbqfncOYHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQcEV3b2SmeHRf44q7tr51+9rkLGtZe/FD5/9BTd5SP8cu/6mJ9xof+t1i/9bwHGtYueedrxW2//+qsYv1TMxt/V0CrXovDxfrmQwPF+pKTjjS97/d//7pi/QNDW5p+7Dptjo06EPvLf1AVzvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kNSEn+e3vVbSpyWNRMR51bLZku6TtEDSLklXRcQEH2qf2ga+u7lQa+2xT2ltc/3NLy9pWPuLCxeU9/2v5TkHbl3y/iY6mpzprx0r1ge27S3WT9t0f7H+azMaz3cwc1d5LoQMJnPm/5akS9+07GZJGyPiHEkbq/sAppAJwx8RmyTtf9PipZLWVbfXSbqyzX0B6LBmn/PPiYjj12TPSyrPRwWg57T8gl+Mfjig4ZvXbQ/ZHrY9fESHWt0dgDZpNvz7bM+VpOr3SKMVI2JNRAxGxGCf+pvcHYB2azb86yWtqG6vkPRQe9oB0C0Tht/2PZIelXSu7d22r5F0i6RLbO+U9InqPoApZMJx/ohY1qA0NT+YfwI6+vy+hrWB+xvXJOmNCR574LsvNdFRe+z7vY8V6x+cUf7z/fL+cxvWFvzdc8VtjxarJwbe4QckRfiBpAg/kBThB5Ii/EBShB9Iiim6UZvpZ80v1r+68qvFep+nFevfWf2J
hrXT9j5a3DYDzvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBTj/KjNM384r1j/SH95pumnDpenH5/99Ktvu6dMOPMDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKM86OjDn3qIw1rj3/u9gm2Ls/w9Ps33lisv/PffjTB4+fGmR9IivADSRF+ICnCDyRF+IGkCD+QFOEHkppwnN/2WkmfljQSEedVy1ZJulbSC9VqKyNiQ6eaxNT135c1Pr/Mcnkcf9l/XVKsz/zBE8V6FKuYzJn/W5IuHWf57RGxqPoh+MAUM2H4I2KTpP1d6AVAF7XynP8G29tsr7V9ats6AtAVzYb/65LOlrRI0l5JtzVa0faQ7WHbw0d0qMndAWi3psIfEfsi4o2IOCbpG5IWF9ZdExGDETHYN8EHNQB0T1Phtz13zN3PSHqyPe0A6JbJDPXdI2mJpNNt75b0Z5KW2F6k0dGUXZKu62CPADpgwvBHxLJxFt/RgV4wBb3j5JOL9eW/8UjD2oFjrxe3HfnS+4r1/kNbinWU8Q4/ICnCDyRF+IGkCD+QFOEHkiL8QFJ8dTdasnPVB4v1753+tw1rS3d+trht/waG8jqJMz+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJMU4P4r+73c+Wqxv++2/LtZ/fPRIw9orf3Vmcdt+7S3W0RrO/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOP8yU2f9yvF+k1fvK9Y73f5T+jqJ5Y3rL37H/i8fp048wNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUhOO89ueL+lOSXMkhaQ1EbHa9mxJ90laIGmXpKsi4uXOtYpmeHr5n/j87+0u1j8/66Vi/e6DZxTrc77Y+PxyrLglOm0yZ/6jkr4QEQslfVTS9bYXSrpZ0saIOEfSxuo+gCliwvBHxN6IeLy6fVDSDknzJC2VtK5abZ2kKzvVJID2e1vP+W0vkPRhSZslzYmI49+z9LxGnxYAmCImHX7bsyTdL+mmiDgwthYRodHXA8bbbsj2sO3hIzrUUrMA2mdS4bfdp9Hg3x0RD1SL99meW9XnShoZb9uIWBMRgxEx2Kf+dvQMoA0mDL9tS7pD0o6I+MqY0npJK6rbKyQ91P72AHTKZD7Se6Gk5ZK2295aLVsp6RZJf2/7Gkk/kXRVZ1pES84/t1j+8zPuaunhv/alzxfr73ri0ZYeH50zYfgj4hFJblC+uL3tAOgW3uEHJEX4gaQIP5AU4QeSIvxAUoQfSIqv7j4BTFv4gYa1oXtbe+/VwrXXF+sL7vr3lh4f9eHMDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJMc5/AnjmD05tWLti5oGGtck4818Ol1eIcb+9DVMAZ34gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIpx/ing9SsWF+sbr7itUJ3Z3mZwwuDMDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJTTjOb3u+pDslzZEUktZExGrbqyRdK+mFatWVEbGhU41m9j8XTivW3zO9+bH8uw+eUaz3HSh/np9P809dk3mTz1FJX4iIx22fLOkx2w9Xtdsj4sudaw9Ap0wY/ojYK2lvdfug7R2S5nW6MQCd9bae89teIOnDkjZXi26wvc32WtvjfpeU7SHbw7aHj+hQS80CaJ9Jh9/2LEn3S7opIg5I+rqksyUt0uiVwbhvMI+INRExGBGDfepvQ8sA2mFS4bfdp9Hg3x0RD0hSROyLiDci4pikb0gqf/oEQE+ZMPy2LekOSTsi4itjls8ds9pnJD3Z/vYAdMpkXu2/UNJySdttb62WrZS0zPYijY727JJ0XUc6REv+8qWFxfqjv7WgWI+929vYDXrJZF7tf0SSxykxpg9MYbzDD0iK8ANJEX4gKcIPJEX4gaQIP5CUo4tTLJ/i2XGBL+7a/oBsNsdGHYj94w3NvwVnfiApwg8kRfiBpAg/kBThB5Ii/EBShB9Iqqvj/LZfkPSTMYtOl/Ri
1xp4e3q1t17tS6K3ZrWzt7Mi4t2TWbGr4X/Lzu3hiBisrYGCXu2tV/uS6K1ZdfXGZT+QFOEHkqo7/Gtq3n9Jr/bWq31J9NasWnqr9Tk/gPrUfeYHUJNawm/7Utv/YftZ2zfX0UMjtnfZ3m57q+3hmntZa3vE9pNjls22/bDtndXvcadJq6m3Vbb3VMduq+3La+ptvu1/tv207ads31gtr/XYFfqq5bh1/bLf9jRJ/ynpEkm7JW2RtCwinu5qIw3Y3iVpMCJqHxO2/ZuSXpF0Z0ScVy27VdL+iLil+o/z1Ij4kx7pbZWkV+qeubmaUGbu2JmlJV0p6XdV47Er9HWVajhudZz5F0t6NiKei4jDku6VtLSGPnpeRGyStP9Ni5dKWlfdXqfRP56ua9BbT4iIvRHxeHX7oKTjM0vXeuwKfdWijvDPk/TTMfd3q7em/A5JP7T9mO2hupsZx5xq2nRJel7SnDqbGceEMzd305tmlu6ZY9fMjNftxgt+b3VRRPy6pMskXV9d3vakGH3O1kvDNZOaublbxplZ+ufqPHbNznjdbnWEf4+k+WPun1kt6wkRsaf6PSLpQfXe7MP7jk+SWv0eqbmfn+ulmZvHm1laPXDsemnG6zrCv0XSObbfa3uGpKslra+hj7ewPVC9ECPbA5I+qd6bfXi9pBXV7RWSHqqxl1/QKzM3N5pZWjUfu56b8Toiuv4j6XKNvuL/Y0l/WkcPDfp6n6Qnqp+n6u5N0j0avQw8otHXRq6RdJqkjZJ2SvonSbN7qLe7JG2XtE2jQZtbU28XafSSfpukrdXP5XUfu0JftRw33uEHJMULfkBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkvp/uK0ZUt56JeQAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%local\n",
    "%matplotlib inline\n",
    "plot_test_example(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "label:2\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADfFJREFUeJzt3X+MHHd5x/HPY+dsB8ckdu1eD8eN42A3OKniwMmQNm2JQmiwEA5Sm2K19IICBpW0RbIEkanUIH4oqkhSqiKQIRZOlR+E/CBGpBDnAIXQk+NzMLYTAzbpUexefLF81E5b7LvLwx87Rpfk5rvr3dmZPT/vl3S63Xl2Zh6t/bnZ3e/sfM3dBSCeGVU3AKAahB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBnlbmzWTbb52humbsEQvmV/lcn/YQ18tiWwm9m10r6nKSZkr7s7remHj9Hc/Vmu7qVXQJI2O79DT+26Zf9ZjZT0uclvUPSSknrzGxls9sDUK5W3vOvlnTA3Z9z95OS7pO0tpi2ALRbK+FfLOkXk+4fzJa9jJmtN7NBMxsc04kWdgegSG3/tN/dN7l7r7v3dml2u3cHoEGthP+QpCWT7p+fLQMwDbQS/h2SlpvZhWY2S9J7JG0tpi0A7db0UJ+7j5vZTZK+rdpQ32Z3f6awzgC0VUvj/O7+qKRHC+oFQIk4vRcIivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaBKvXQ3mjP0qSuS9Yk5nltbdMkLyXUHLnuwqZ5Oueg770vW5z11dm6t+1/+o6V9ozUc+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMb5O8DoN5cn63tX/Wvb9j2Wf4pAQ3581ZeT9bt7e3Jr92/7k+S6E/v2N9UTGsORHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCammc38yGJB2XNCFp3N17i2jqTFNvHP8Hq+5r276/+MtlyfrtA9ck60svSF8P4LGVDyXrfzlvOLf26RsWJtdd9jHG+dupiJN8rnL3IwVsB0CJeNkPBNVq+F3SY2a208zWF9EQgHK0+rL/Snc/ZGa/LWmbmf3Y3Z+Y/IDsj8J6SZqj17S4OwBFaenI7+6Hst8jkh6WtHqKx2xy91537+3S7FZ2B6BATYffzOaa2bxTtyW9XdLeohoD0F6tvOzvlvSwmZ3azj3u/q1CugLQdk2H392fk3RZgb1MW+NXvylZ/85ln6+zha5k9Z9HVyTr3/2LxOkV/z2SXHfF6GCyPmPOnGT9M9t/P1nfuHBPbm18/nhyXbQXQ31AUIQfCIrwA0ERfiAowg8ERfiBoLh0dwFeXDwrWZ9R529svaG8770rPZw28dxPkvVWHPjE5cn6PQtuq7OF/LM6z/8Wx54q8ewDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM8xfgvLsGkvU/G/yrZN1GjyXr48NDp9lRcd6/5vFk/ZwZXJ1puuLIDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBMc5fgolnf1p1C7mGPn1Fsn7jeZ+ts4X0pb03DL8ltzbv8X3JdSfq7Bmt4cgPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0HVHec3s82S3ilpxN0vzZYtkPRVSUslDUm63t1H29cmmvXL96bH8X/w1+lx/HNnpMfxB07MTNZ3fSr/uv9nH3squS7aq5Ej/1ckXfuKZTdL6nf35ZL6s/sAppG64Xf3JyQdfcXitZK2ZLe3SLqu4L4AtFmz7/m73X04u/28pO6C+gFQkpY/8HN3l+R5dTNbb2aDZjY4phOt7g5AQZoN/2Ez65Gk7PdI3gPdfZO797p7b1di0kYA5Wo2/Fsl9WW3+yQ9Ukw7AMpSN/xmdq+kAUm/Z2YHzexGSbdKusbM9kt6W3YfwDRSd5zf3dfllK4uuBe0wZE3
5n4cI6n+OH49fd97f7K+4uuM5XcqzvADgiL8QFCEHwiK8ANBEX4gKMIPBMWlu88AJ7ddkFsbuPi2Omunh/ouG+hL1t+w4WfJOpff7lwc+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMb5p4Gzli1N1j/5+q/l1ubX+cruzjpXVrvgk+mR+olRrtg+XXHkB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgGOefBi66/1Cyfvms5v+Gr+v/ULK+4kc7mt42OhtHfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8Iqu44v5ltlvROSSPufmm27BZJH5D0Qvawje7+aLuaPNON9l2RrH+iu96192fnVvqG3pZc8w0fPZCsc939M1cjR/6vSLp2iuV3uPuq7IfgA9NM3fC7+xOSjpbQC4AStfKe/yYz221mm81sfmEdAShFs+H/gqSLJK2SNCwp902pma03s0EzGxxTnQvGAShNU+F398PuPuHuL0n6kqTVicducvded+/tSnwwBaBcTYXfzHom3X23pL3FtAOgLI0M9d0r6a2SFprZQUn/KOmtZrZKkksakvTBNvYIoA3qht/d102x+M429HLGOmvx65L1P/q77cn6OTOaf7s08Ozrk/UVo3xfPyrO8AOCIvxAUIQfCIrwA0ERfiAowg8ExaW7S7Bv45Jk/eu/842Wtn/Vnj/PrfGVXeThyA8ERfiBoAg/EBThB4Ii/EBQhB8IivADQTHOX4Kd77qjziNau8LRuX/zUm5tfHS0pW3jzMWRHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCYpz/DDDWfW5urevk4hI7ebWJF47k1vxEevo2m50+/2HmooVN9SRJE4vOS9b3b5jV9LYb4ROWW7v4b+tcg+HYsUJ64MgPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0HVHec3syWS7pLULcklbXL3z5nZAklflbRU0pCk692dL49X4JsPbK66hVx/8MOpZnivOXL4tcl15y86nqxvf9M9TfXU6Vb+w03J+rKPDhSyn0aO/OOSNrj7SklvkfRhM1sp6WZJ/e6+XFJ/dh/ANFE3/O4+7O5PZ7ePS9onabGktZK2ZA/bIum6djUJoHin9Z7fzJZKulzSdknd7j6clZ5X7W0BgGmi4fCb2TmSHpT0EXd/2cnF7u6qfR4w1XrrzWzQzAbHlD6XG0B5Ggq/mXWpFvy73f2hbPFhM+vJ6j2SRqZa1903uXuvu/d2tXihSgDFqRt+MzNJd0ra5+63TyptldSX3e6T9Ejx7QFoF6u9Yk88wOxKSd+XtEfSqWtEb1Ttff/9kn5X0s9VG+o7mtrWa22Bv9mubrXnaef/v31hst5/6QMldRLL//nJ3NqY51/uvBFrdt+QrP/Prua/btzz5HiyPvvfd+TWtnu/jvnR/O8LT1J3nN/dn5SUt7F4SQbOEJzhBwRF+IGgCD8QFOEHgiL8QFCEHwiKS3eX4Ow//c9k/ZLPpL/C6W38V5p3cfLUjLZ+bfaS778vWff/mtvS9pc98GJ+8ak9LW17vva3VO8EHPmBoAg/EBThB4Ii/EBQhB8IivADQRF+IKi63+cvUtTv8wNlOZ3v83PkB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaDqht/MlpjZd83sWTN7xsz+Plt+i5kdMrNd2c+a9rcLoCiNTAcxLmmDuz9tZvMk7TSzbVntDnf/bPvaA9AudcPv7sOShrPbx81sn6TF7W4MQHud1nt+M1sq6XJJ27NFN5nZbjPbbGbzc9ZZb2aDZjY4phMtNQugOA2H38zOkfSgpI+4+zFJX5B0kaRVqr0yuG2q9dx9k7v3untvl2YX0DKAIjQUfjPrUi34d7v7Q5Lk7ofdfcLdX5L0JUmr29cmgKI18mm/SbpT0j53v33S8p5JD3u3pL3FtwegXRr5tP8PJb1X0h4z25Ut2yhpnZmtkuSS
hiR9sC0dAmiLRj7tf1LSVNcBf7T4dgCUhTP8gKAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQZm7l7czsxck/XzSooWSjpTWwOnp1N46tS+J3ppVZG8XuPuiRh5YavhftXOzQXfvrayBhE7trVP7kuitWVX1xst+ICjCDwRVdfg3Vbz/lE7trVP7kuitWZX0Vul7fgDVqfrID6AilYTfzK41s5+Y2QEzu7mKHvKY2ZCZ7clmHh6suJfNZjZiZnsnLVtgZtvMbH/2e8pp0irqrSNmbk7MLF3pc9dpM16X/rLfzGZK+qmkayQdlLRD0jp3f7bURnKY2ZCkXnevfEzYzP5Y0ouS7nL3S7Nl/yTpqLvfmv3hnO/uH+uQ3m6R9GLVMzdnE8r0TJ5ZWtJ1km5Qhc9doq/rVcHzVsWRf7WkA+7+nLuflHSfpLUV9NHx3P0JSUdfsXitpC3Z7S2q/ecpXU5vHcHdh9396ez2cUmnZpau9LlL9FWJKsK/WNIvJt0/qM6a8tslPWZmO81sfdXNTKE7mzZdkp6X1F1lM1OoO3NzmV4xs3THPHfNzHhdND7we7Ur3f2Nkt4h6cPZy9uO5LX3bJ00XNPQzM1lmWJm6d+o8rlrdsbrolUR/kOSlky6f362rCO4+6Hs94ikh9V5sw8fPjVJavZ7pOJ+fqOTZm6eamZpdcBz10kzXlcR/h2SlpvZhWY2S9J7JG2toI9XMbO52QcxMrO5kt6uzpt9eKukvux2n6RHKuzlZTpl5ua8maVV8XPXcTNeu3vpP5LWqPaJ/88kfbyKHnL6WibpR9nPM1X3Jule1V4Gjqn22ciNkn5LUr+k/ZIel7Sgg3r7N0l7JO1WLWg9FfV2pWov6XdL2pX9rKn6uUv0Vcnzxhl+QFB84AcERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+IKhfA10jPiPOz+MkAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%local\n",
    "%matplotlib inline\n",
    "plot_test_example(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4: Save  the Numpy Dataset as a Petastorm Dataset\n",
    "\n",
    "We have now parsed the MNIST database of handwritten images into numpy arrays `train_images, train_labels, test_images, test_labels`. We can store this to HDFS as a Petastorm dataset by (1) specifying the petastorm schema; (2) convert the numpy data to a spark dataframe; (3) Save the data to petastorm training dataset in the featurestore using  `featurestore.create_training_dataset()` and set `data_format='petastorm'`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Specify Petastorm Schema\n",
    "\n",
    "Petastorm datasets are designed to be easy to read into deep learning frameworks such as Tensorflow and Pytorch. To enable this we need to specify the schema when writing the data to HDFS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "MnistSchema = Unischema('MnistSchema', [\n",
    "    UnischemaField('image', np.uint8, (28, 28,1), NdarrayCodec(), False),\n",
    "    UnischemaField('digit', np.int_, (), ScalarCodec(IntegerType()), False)\n",
    "])"
   ]
  },
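  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The schema above fixes the `image` field to dtype `np.uint8` and shape `(28, 28, 1)`, so it can be useful to normalise arrays to that layout before building rows. The following is a minimal NumPy-only sketch of such a check. The helper name `to_schema_image` is hypothetical and not part of Petastorm or the `hops` library, and the accepted input shapes are an assumption made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def to_schema_image(image):\n",
    "    \"\"\"\n",
    "    Hypothetical helper: returns `image` as a uint8 array of shape (28, 28, 1),\n",
    "    or raises ValueError if it cannot be represented that way.\n",
    "    \"\"\"\n",
    "    arr = np.asarray(image)\n",
    "    if arr.shape == (28, 28):\n",
    "        # add the single grayscale channel expected by the schema\n",
    "        arr = arr.reshape(28, 28, 1)\n",
    "    if arr.shape != (28, 28, 1):\n",
    "        raise ValueError('expected shape (28, 28) or (28, 28, 1), got %s' % (arr.shape,))\n",
    "    if arr.min() < 0 or arr.max() > 255:\n",
    "        raise ValueError('pixel values must be in the range [0, 255]')\n",
    "    return arr.astype(np.uint8)\n",
    "\n",
    "# usage on a placeholder array (not real MNIST data)\n",
    "example = to_schema_image(np.zeros((28, 28), dtype=np.float32))\n",
    "print('%s %s' % (example.shape, example.dtype))"
   ]
  },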
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create a Spark Dataframe with the data in the numpy arrays\n",
    "\n",
    "Petastorm relies on Spark to write to HDFS so we have to convert the numpy arrays into a spark dataframe that conforms to the Petastorm schema."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_OUTPUT_PATH = hdfs.project_path() + \"mnist/train_petastorm\"\n",
    "TEST_OUTPUT_PATH = hdfs.project_path() + \"mnist/test_petastorm\" \n",
    "mnist_data = {\n",
    "            'train': {\"images\": train_images, \"labels\": train_labels, \"output_path\": TRAIN_OUTPUT_PATH},\n",
    "            'test': {\"images\": test_images, \"labels\": test_labels, \"output_path\": TEST_OUTPUT_PATH}\n",
    "        }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First create an RDD of dict-rows with the numpy arrays and then convert the dict-rows into Spark Dataframe Rows using the petastorm utility function `dict_to_spark_row`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_petastorm_dict(i, test=False):\n",
    "    \"\"\"\n",
    "    Returns a train or test example from mnist_data variable as a dict with the Petastorm schema\n",
    "    \n",
    "    Args:\n",
    "        :i: the index of the example\n",
    "        :test: a boolean flag whether to get a test example, otherwise gets a training example\n",
    "    \"\"\"\n",
    "    if test:\n",
    "        return {\n",
    "            MnistSchema.digit.name: mnist_data[\"test\"][\"labels\"][i],\n",
    "            MnistSchema.image.name: mnist_data[\"test\"][\"images\"][i]\n",
    "        }\n",
    "    return {\n",
    "            MnistSchema.digit.name: mnist_data[\"train\"][\"labels\"][i],\n",
    "            MnistSchema.image.name: mnist_data[\"train\"][\"images\"][i]\n",
    "        } "
   ]
  },
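  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`create_petastorm_dict` above selects the split with a boolean flag. As an illustration of the dict-row layout, the sketch below indexes the split by name instead. `make_dict_row` and `toy_data` are hypothetical stand-ins (the field names `digit` and `image` mirror `MnistSchema`), and the arrays are tiny placeholders rather than MNIST data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def make_dict_row(data, split, i):\n",
    "    \"\"\"\n",
    "    Hypothetical variant of create_petastorm_dict: returns example `i` of the\n",
    "    given split ('train' or 'test') as a dict keyed by the schema field names.\n",
    "    \"\"\"\n",
    "    split_data = data[split]\n",
    "    return {\n",
    "        'digit': split_data['labels'][i],\n",
    "        'image': split_data['images'][i]\n",
    "    }\n",
    "\n",
    "# placeholder arrays with the same layout as mnist_data (not real MNIST data)\n",
    "toy_data = {\n",
    "    'train': {'images': np.zeros((2, 28, 28, 1), dtype=np.uint8), 'labels': np.array([5, 0])},\n",
    "    'test': {'images': np.ones((1, 28, 28, 1), dtype=np.uint8), 'labels': np.array([7])}\n",
    "}\n",
    "row = make_dict_row(toy_data, 'test', 0)\n",
    "print('%s %s' % (row['digit'], row['image'].shape))"
   ]
  },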
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_train_rdd = sc.parallelize(list(range(len(mnist_data[\"train\"][\"images\"]))))\\\n",
    "        .map(lambda i: create_petastorm_dict(i))\\\n",
    "        .map(lambda x: dict_to_spark_row(MnistSchema, x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_test_rdd = sc.parallelize(list(range(len(mnist_data[\"test\"][\"images\"]))))\\\n",
    "        .map(lambda i: create_petastorm_dict(i, test=True))\\\n",
    "        .map(lambda x: dict_to_spark_row(MnistSchema, x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then convert the RDDs to Spark Dataframes using `MnistSchema.as_spark_schema()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train = spark.createDataFrame(dataset_train_rdd, MnistSchema.as_spark_schema())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test = spark.createDataFrame(dataset_test_rdd, MnistSchema.as_spark_schema())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note** Using dicts/rdds and then spark.createDataFrame() is just one approach to create a spark dataframe that conforms to the Petastorm schema. If you don't have the data in numpy arrays but rather saved on disk you might use another approach, the important thing is that you define a petastorm schema and that your dataframe's spark schema is equal to `petastorm_schema.as_spark_schema()`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Save the Spark Dataframes as Training Dataset in the Petastorm Format in the Hopsworks Feature Store"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All parameters available for the petastorm method `materialize_dataset` listed here: https://petastorm.readthedocs.io/en/latest/api.html can be provided as arguments in a dict to `featurestore.create_training_dataset(petastorm_args=dict)`. The only required argument is \"schema\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "petastorm_args = {\n",
    "    \"schema\": MnistSchema\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save Dataset with Training Data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "computing descriptive statistics for : MNIST_train_petastorm"
     ]
    }
   ],
   "source": [
    "featurestore.create_training_dataset(df_train, \"MNIST_train_petastorm\", data_format=\"petastorm\", \n",
    "                                     petastorm_args=petastorm_args,\n",
    "                                     description=\"MNIST Digit Database Training Dataset of 60000 images stored in the Petastorm Format\",\n",
    "                                     #this type of statistics do not make sense on image data\n",
    "                                     feature_correlation=False, feature_histograms=False, cluster_analysis=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save Dataset with Test Data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "computing descriptive statistics for : MNIST_test_petastorm"
     ]
    }
   ],
   "source": [
    "featurestore.create_training_dataset(df_test, \"MNIST_test_petastorm\", data_format=\"petastorm\", \n",
    "                                     petastorm_args=petastorm_args,\n",
    "                                     description=\"MNIST Digit Database Test Dataset of 10000 images stored in the Petastorm Format\",\n",
    "                                     #this type of statistics do not make sense on image data\n",
    "                                     feature_correlation=False, feature_histograms=False, cluster_analysis=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 5: Use the Training Dataset\n",
    "\n",
    "Once the datasets are saved in the feature store you can find them in the Feature Registry inside your project and easily reuse it later. By saving the dataset as a petastorm dataset you can use the same data in several different deep learning frameworks, such as PyTorch and Tensorflow. We have two example notebooks that demonstrate this here:\n",
    "\n",
    "- [Tensorflow Example](PetastormMNIST_Tensorflow.ipynb)\n",
    "\n",
    "- [PyTorch Example](PetastormMNIST_PyTorch.ipynb)\n",
    "\n",
    "![Petastorm 5](./../images/petastorm5.png \"Petastorm 5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PySpark",
   "language": "",
   "name": "pysparkkernel"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "python",
    "version": 2
   },
   "mimetype": "text/x-python",
   "name": "pyspark",
   "pygments_lexer": "python2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}